{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ca0ddf27-b42d-4004-9f0c-4f43e55a245b",
   "metadata": {},
   "source": [
    "### `torchdata.nodes` Basics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0ff94f5-fc60-4a7e-a3e5-0d7a4bfa2ddd",
   "metadata": {},
   "source": [
    "All torchdata.nodes.BaseNode implementations are Iterators, adhering to the following API:\n",
    "```Python\n",
    "class BaseNode(Iterator[T]):\n",
    "    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:\n",
    "        \"\"\"Resets the node to its initial state or a specified state.\"\"\"\n",
    "        ...\n",
    "    def __next__(self) -> T:\n",
    "        \"\"\"Returns the next value in the sequence.\"\"\"\n",
    "        ...\n",
    "    def get_state(self) -> Dict[str, Any]:\n",
    "        \"\"\"Returns a dictionary representing the current state of the node.\"\"\"\n",
    "        ...\n",
    "```\n",
    "This standardized interface enables seamless chaining of iterators, allowing for flexible, efficient, and composable data processing pipelines."
   ]
  },
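  {
   "cell_type": "markdown",
   "id": "basenode-protocol-sketch",
   "metadata": {},
   "source": [
    "The state-handling half of this API can be illustrated without the library. The block below is a minimal sketch, not torchdata code: `RangeNode` and its `next_value` state key are hypothetical names, and the class is a pure-Python stand-in that only mirrors the three methods listed above. It shows how `get_state()` and `reset(initial_state)` could work together to resume iteration from a snapshot.\n",
    "\n",
    "```python\n",
    "from typing import Any, Dict, Iterator, Optional\n",
    "\n",
    "\n",
    "class RangeNode(Iterator[int]):\n",
    "    # Hypothetical stand-in mirroring the BaseNode interface; not part of torchdata.\n",
    "\n",
    "    def __init__(self, stop: int) -> None:\n",
    "        self.stop = stop\n",
    "        self._next_value = 0\n",
    "\n",
    "    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:\n",
    "        # Start from zero, or resume from a previously saved position.\n",
    "        self._next_value = 0 if initial_state is None else initial_state['next_value']\n",
    "\n",
    "    def __next__(self) -> int:\n",
    "        if self._next_value >= self.stop:\n",
    "            raise StopIteration\n",
    "        value = self._next_value\n",
    "        self._next_value += 1\n",
    "        return value\n",
    "\n",
    "    def get_state(self) -> Dict[str, Any]:\n",
    "        # Everything needed to resume: the next value to be produced.\n",
    "        return {'next_value': self._next_value}\n",
    "\n",
    "\n",
    "node = RangeNode(5)\n",
    "node.reset()\n",
    "first_two = [next(node), next(node)]  # consume two items\n",
    "state = node.get_state()              # snapshot taken mid-iteration\n",
    "\n",
    "resumed = RangeNode(5)\n",
    "resumed.reset(state)                  # restore the snapshot in a fresh node\n",
    "rest = list(resumed)                  # continues where the first node stopped\n",
    "```\n",
    "\n",
    "A real pipeline would presumably subclass `BaseNode` rather than `Iterator` directly; the stand-in is used here only to keep the sketch dependency-free."
   ]
  },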
  {
   "cell_type": "markdown",
   "id": "5542e957-f40c-4624-9e3d-b4130d1fd03a",
   "metadata": {},
   "source": [
    "Let's see the functionalities of `torchdata.nodes` through the help of a very simple example."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5168248-7c08-419d-a9a5-f831bd7b46e7",
   "metadata": {},
   "source": [
    "#### IterableWrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f42819f7-1019-48ca-ad23-00b38c72d63e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchdata.nodes import IterableWrapper\n",
    "# This Wrapper converts any Iterable in to a BaseNode.\n",
    "\n",
    "dataset = range(10) # creating a very simple dataset, and then converting it into a BaseNode\n",
    "source = IterableWrapper(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8bb313fe-5cad-4c5f-b7f3-4bd27b0645a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n"
     ]
    }
   ],
   "source": [
    "# Let's take a look at the items in the node\n",
    "for item in source:\n",
    "    print(item)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "247845d5-ef4b-4bc3-92c3-93c55972578a",
   "metadata": {},
   "source": [
    "#### Integrating with torch.data Dataloaders and Samplers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97fb43c8-16c0-4c87-bb3a-ff6a0e126764",
   "metadata": {},
   "source": [
    "We can also use `torch.data.utils` style dataloaders and samplers, and then wrap them into nodes.\n",
    "Please refer to this [migration guide](https://pytorch.org/data/main/migrate_to_nodes_from_utils.html) to migrate from torch.utils.data to torchdata.nodes\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a14e03a0-29f7-4824-8860-75f1f2a91533",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6\n",
      "9\n",
      "2\n",
      "1\n",
      "7\n",
      "5\n",
      "4\n",
      "3\n",
      "8\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "from torchdata.nodes import MapStyleWrapper\n",
    "from torch.utils.data import RandomSampler\n",
    "\n",
    "sampler = RandomSampler(dataset)\n",
    "node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n",
    "\n",
    "for item in node:\n",
    "    print(item)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73f6e1ce-6e33-49cd-b0bb-c49fee0e7cd6",
   "metadata": {},
   "source": [
    "#### Map"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a5246cc-6f40-4942-b514-df76c3946eee",
   "metadata": {},
   "source": [
    "We can use the Mapper class, to apply a transformation defined using the `map_fn`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "71a07a13-5d6f-46d4-b316-102698eb13c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "4\n",
      "9\n",
      "16\n",
      "25\n",
      "36\n",
      "49\n",
      "64\n",
      "81\n"
     ]
    }
   ],
   "source": [
    "from torchdata.nodes import Mapper\n",
    "node = Mapper(source, map_fn = lambda x : x**2)\n",
    "for item in node:\n",
    "    print(item)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0701898c-6c72-4065-aee9-23a6b516876d",
   "metadata": {},
   "source": [
    "It can also be executed in parallel, using the multi threading/processing approaches, depending on the defined `method`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0771661d-f0dc-4d35-a5af-abab84d7efd3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "4\n",
      "9\n",
      "16\n",
      "25\n",
      "36\n",
      "49\n",
      "64\n",
      "81\n"
     ]
    }
   ],
   "source": [
    "from torchdata.nodes import ParallelMapper\n",
    "mapper = ParallelMapper(source, map_fn = lambda x : x**2, num_workers =2, method = \"thread\")\n",
    "for item in mapper:\n",
    "    print(item)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b517ebcf-9f92-43bf-b332-3875933ea244",
   "metadata": {},
   "source": [
    "#### Batch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6156bcb4-891f-484a-bf06-cbf39e282175",
   "metadata": {},
   "source": [
    "A BaseNode can be passed into a Batcher, to get batches of size `batch_size`.\n",
    "By default, `drop_last` is True, meaning if the last batch has a size smaller than the `batch_size`, it is not produced."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dc40c637-90ed-4c09-9961-3b0d488fc6d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1, 2, 3]\n",
      "[4, 5, 6, 7]\n"
     ]
    }
   ],
   "source": [
    "from torchdata.nodes import Batcher\n",
    "batcher = Batcher(source, batch_size = 4)\n",
    "for batch in batcher:\n",
    "    print(batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5487e01d-6245-4940-a0df-c32aee0849cf",
   "metadata": {},
   "source": [
    "We can make `drop_last = False` to produce the last batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "092d8e5b-aeb2-4de1-9b13-8b0dc0af31ce",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1, 2, 3]\n",
      "[4, 5, 6, 7]\n",
      "[8, 9]\n"
     ]
    }
   ],
   "source": [
    "batcher = Batcher(source, batch_size = 4, drop_last = False)\n",
    "for batch in batcher:\n",
    "    print(batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e11543b9-c2dd-4861-af24-0eb3a1f3ddb9",
   "metadata": {},
   "source": [
    "If we try to use this batcher over multiple epochs, we will need to reset it after every epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d89e36a9-cc9c-4b55-acba-dbcaf8d53407",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch = 0  Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
      "Epoch = 1  Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n"
     ]
    }
   ],
   "source": [
    "batcher = Batcher(source, batch_size = 10)\n",
    "num_epochs = 2\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for batch in batcher:\n",
    "        print(f\"Epoch = {epoch}\", f\" Batch = {batch}\")\n",
    "    batcher.reset()\n",
    "    \n",
    "# This is one extra step than traditional dataloader, we can actually wrap the batcher in a Loader to skip that\n",
    "# Let's look at Loader in the next cell"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2232823-0927-432b-9171-fb2bf38a7205",
   "metadata": {},
   "source": [
    "#### Loader\n",
    "As you can see, we get a batch in every epoch, without even needing to reset the loader!!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4f8a1830-49a2-4dfd-b17b-acaca7ac2e98",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch = 0  Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
      "Epoch = 1  Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n"
     ]
    }
   ],
   "source": [
    "from torchdata.nodes import Loader\n",
    "batcher = Batcher(source, batch_size = 10)\n",
    "loader = Loader(batcher)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for batch in loader:\n",
    "        print(f\"Epoch = {epoch}\", f\" Batch = {batch}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "860142c7-a39b-441b-b2b2-1cbd81a7a0f3",
   "metadata": {},
   "source": [
    "#### SamplerWrapper"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18134e7d-d0e8-4096-a674-23e715fa16a3",
   "metadata": {},
   "source": [
    "As mentioned earlier, we can use `torch.data.utils` samplers using `MapStyleWrapper`.\n",
    "Alternatively, we can employ the `SamplerWrapper`, which converts a `Sampler` into a `BaseNode`. `SamplerWrapper` differs from `IterableWrapper` because it will track the number of epochs, and call the sampler's `set_epoch` method if it is implemented."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f794c4d9-b3ed-4b7b-8c85-131d495e6605",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch = 0  Batch = [2, 9, 0, 5, 8, 6, 7, 3, 1, 4]\n",
      "Epoch = 1  Batch = [2, 3, 6, 8, 5, 0, 1, 9, 7, 4]\n"
     ]
    }
   ],
   "source": [
    "from torchdata.nodes import SamplerWrapper\n",
    "\n",
    "sampler = RandomSampler(dataset)\n",
    "node = SamplerWrapper(sampler)\n",
    "batcher = Batcher(node, batch_size = 10)\n",
    "loader = Loader(batcher)\n",
    "for epoch in range(num_epochs):\n",
    "    \n",
    "    for batch in loader:\n",
    "        print(f\"Epoch = {node.epoch}\", f\" Batch = {batch}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e956793-df2d-4d0f-94b0-daa7fb860bf5",
   "metadata": {},
   "source": [
    "`torchadata` nodes are composable, thus, many BaseNodes type nodes can be chained together for desired transformations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f1a6291-d688-45eb-bf2e-56153419513e",
   "metadata": {},
   "source": [
    "#### Chaining multiple operations together\n",
    "\n",
    "Base nodes are iterators, and are designed to be chained together to create more complex dataloading graphs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cc4bcd30-9a76-4a1e-aa38-d7db8b9a8c2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "sampler = RandomSampler(dataset)\n",
    "node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n",
    "node = Mapper(node, map_fn = lambda x : x**3)\n",
    "node = Batcher(node, batch_size = 4, drop_last = False)\n",
    "loader = Loader(node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "798f64af-102d-4b24-8327-b47d4d0fc4f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[64, 216, 27, 729]\n",
      "[0, 8, 343, 1]\n",
      "[512, 125]\n"
     ]
    }
   ],
   "source": [
    "for batch in loader:\n",
    "    print(batch)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
